import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from Resnet_attention import ResNet18SelfAttentionPredictor


def construct_input_tensor(n_qubits, n_layers, variant=0, num_variants=20, device="cpu"):
    """Build the coordinate + variant input grid for the predictor.

    The returned float32 tensor has shape ``(1, 3, n_layers, 2 * n_qubits)``:

    * channel 0 -- layer coordinate, evenly spaced in [0, 1] down the rows
      (constant along each row);
    * channel 1 -- column coordinate, evenly spaced in [0, 1] across the
      ``2 * n_qubits`` columns (constant along each column);
    * channel 2 -- a constant plane holding ``variant / (num_variants - 1)``,
      or 0.0 when ``num_variants <= 1``.

    The tensor is moved to ``device`` before being returned.
    """
    width = 2 * n_qubits
    # "ij" indexing keeps rows aligned with layers and columns with gate slots.
    layer_plane, column_plane = np.meshgrid(
        np.linspace(0, 1, n_layers),
        np.linspace(0, 1, width),
        indexing="ij",
    )
    if num_variants > 1:
        variant_value = variant / (num_variants - 1)
    else:
        variant_value = 0.0
    variant_plane = np.full((n_layers, width), variant_value, dtype=np.float32)
    stacked = np.stack(
        [layer_plane.astype(np.float32), column_plane.astype(np.float32), variant_plane],
        axis=0,
    )
    batch = torch.tensor(stacked)[None, ...]
    return batch.to(device)


def predict_intensity(model, n_qubits, n_layers, variant=0, device="cpu", num_variants=20):
    """Run ``model`` on the coordinate/variant grid and return a 0-100 intensity map.

    Side effect: ``model`` is switched to evaluation mode via ``model.eval()``.

    Args:
        model: predictor called on the tensor from ``construct_input_tensor``.
        n_qubits: number of qubits used to size the input grid.
        n_layers: number of layers used to size the input grid.
        variant: variant index encoded in the third input channel.
        device: device the input tensor is moved to (should match the model's).
        num_variants: total number of variants used to normalise ``variant``
            to [0, 1]. Defaults to 20, the value previously hard-coded here.

    Returns:
        numpy.ndarray: the model output with leading size-1 axes 0 and 1
        squeezed away (presumably batch and channel -- TODO confirm against the
        model), min-max rescaled to [0, 100]. If the output is (near-)constant,
        an all-zero array of the same shape is returned instead.
    """
    model.eval()
    input_tensor = construct_input_tensor(
        n_qubits, n_layers, variant=variant, num_variants=num_variants, device=device
    )
    with torch.no_grad():
        pred = model(input_tensor)
    intensity = pred.squeeze(0).squeeze(0).cpu().numpy()
    min_val = intensity.min()
    max_val = intensity.max()
    # A flat prediction carries no ranking information, so avoid dividing by ~0.
    if max_val - min_val > 1e-6:
        return (intensity - min_val) / (max_val - min_val) * 100
    return np.zeros_like(intensity)


def draw_intensity_circuit(intensity, n_qubits, n_layers, title="Predicted Circuit Intensity", vmin=0, vmax=None, show_cz=True, save_filename=None):
    """Draw a layered circuit diagram whose gate markers are coloured by intensity.

    Each layer gets an RX marker at ``x = layer + 0.25`` and an RY marker at
    ``x = layer + 0.75`` on every qubit wire, coloured with the "Blues" colormap.

    Args:
        intensity: array-like holding exactly ``2 * n_qubits * n_layers`` values.
            After reshaping to ``(n_layers, 2 * n_qubits)`` (so flat input is
            accepted), the first ``n_qubits`` columns of each row are RX values
            and the remaining ``n_qubits`` columns are RY values.
        n_qubits: number of qubit wires.
        n_layers: number of circuit layers.
        title: axes title.
        vmin: lower bound of the colour scale.
        vmax: upper bound of the colour scale; defaults to the data maximum.
        show_cz: if True, draw a vertical segment at every layer index between
            each qubit and the next one (the last qubit wraps to qubit 0).
        save_filename: if given, save the figure there and close it;
            otherwise display it with ``plt.show()``.

    Raises:
        ValueError: if ``intensity`` does not hold ``2 * n_qubits * n_layers`` values.
    """
    n_cols = 2 * n_qubits
    grid = np.asarray(intensity)
    if grid.size != n_cols * n_layers:
        raise ValueError("Intensity array size does not match expected dimensions.")
    # Safe after the size check; lets callers pass flat arrays as well as 2-D ones.
    grid = grid.reshape(n_layers, n_cols)
    if vmax is None:
        # An empty grid has no maximum; fall back to vmin instead of crashing.
        vmax = float(grid.max()) if grid.size else vmin
    cmap = plt.get_cmap("Blues")
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(20, 8))
    fig.subplots_adjust(right=0.8)  # leave room for the colorbar axes added below
    ax.set_title(title, fontsize=18)
    for q in range(n_qubits):
        ax.plot([0, n_layers], [q, q], color="black", linewidth=2, alpha=0.5)
    if show_cz:
        for layer_idx in range(n_layers):
            for qubit_idx in range(n_qubits):
                ax.plot([layer_idx, layer_idx], [qubit_idx, (qubit_idx + 1) % n_qubits],
                        color="black", linewidth=2, alpha=1)

    # One scatter call per gate type instead of two calls per layer.
    layer_pos, qubit_pos = np.meshgrid(np.arange(n_layers), np.arange(n_qubits), indexing="ij")
    for x_offset, values in ((0.25, grid[:, :n_qubits]), (0.75, grid[:, n_qubits:])):
        if values.size == 0:
            continue
        ax.scatter((layer_pos + x_offset).ravel(), qubit_pos.ravel(),
                   c=cmap(norm(values.ravel())), s=300, marker="o",
                   edgecolors="k", linewidths=2)

    ax.set_xlim(-0.5, n_layers + 0.5)
    ax.set_ylim(-0.5, n_qubits - 0.5)
    ax.set_xlabel("Layer Index", fontsize=18)
    ax.set_ylabel("Qubit Index", fontsize=18)
    ax.set_xticks(range(n_layers))
    ax.set_yticks(range(n_qubits))
    ax.grid(False)
    for spine in ax.spines.values():
        spine.set_linewidth(3)
    ax.tick_params(axis="both", labelsize=18)

    sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])  # required by older Matplotlib versions for a standalone mappable
    cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label("Intensity", fontsize=18)

    if save_filename is not None:
        fig.savefig(save_filename)
        plt.close(fig)  # free the figure; this is called in a loop by main()
    else:
        plt.show()


def draw_high_count_circuit(positions, n_qubits, n_layers, title="High Count Nodes (Counts > 80)", show_cz=True, save_filename=None):
    """Draw the circuit wiring and mark each high-count gate position with a red cross.

    Args:
        positions: iterable of ``(x, qubit, gate_type, count)`` 4-tuples, in the
            format produced by ``get_high_count_positions``. Only ``x`` and
            ``qubit`` affect the plot.
        n_qubits: number of qubit wires.
        n_layers: number of circuit layers (x-axis extent).
        title: axes title.
        show_cz: if True, draw a vertical segment at every layer index between
            each qubit and the next one (the last qubit wraps to qubit 0).
        save_filename: if given, save the figure there and close it;
            otherwise display it with ``plt.show()``.
    """
    # Materialise once so generators work and emptiness can be tested.
    points = list(positions)

    fig, ax = plt.subplots(figsize=(20, 8))
    fig.subplots_adjust(right=0.8)  # same right margin as the intensity plot
    ax.set_title(title, fontsize=18)
    for q in range(n_qubits):
        ax.plot([0, n_layers], [q, q], color="black", linewidth=2, alpha=0.5)
    if show_cz:
        for layer in range(n_layers):
            for q in range(n_qubits):
                ax.plot([layer, layer], [q, (q + 1) % n_qubits],
                        color="black", linewidth=2, alpha=1)
    if points:
        # Unpack full 4-tuples to keep the input contract; gate type/count are not plotted.
        xs = [x for x, _y, _gate, _count in points]
        ys = [y for _x, y, _gate, _count in points]
        ax.scatter(xs, ys, color="red", s=500, marker="x", linewidths=8)

    ax.set_xlim(-0.5, n_layers + 0.5)
    ax.set_ylim(-0.5, n_qubits - 0.5)
    ax.set_xlabel("Layer Index", fontsize=18)
    ax.set_ylabel("Qubit Index", fontsize=18)
    for spine in ax.spines.values():
        spine.set_linewidth(3)
    ax.tick_params(axis="both", labelsize=18)
    ax.set_xticks(range(n_layers))
    ax.set_yticks(range(n_qubits))
    ax.grid(False)

    if save_filename is not None:
        fig.savefig(save_filename)
        plt.close(fig)  # free the figure; this is called in a loop by main()
    else:
        plt.show()


def get_high_count_positions(intensity, threshold=80, n_qubits=None):
    """List every gate slot whose intensity is strictly above ``threshold``.

    ``intensity`` is indexed as ``intensity[layer, column]``. Columns
    ``0 .. n_qubits-1`` are RX slots and columns ``n_qubits .. 2*n_qubits-1``
    are RY slots. ``n_qubits`` defaults to half the number of columns.

    Returns:
        list of ``(x, qubit, gate_type, value)`` tuples. ``x`` is
        ``layer + 0.25`` for RX and ``layer + 0.75`` for RY, ``gate_type`` is
        ``"RX"`` or ``"RY"``, and ``value`` is the raw entry from
        ``intensity``. Within each layer all RX hits come before all RY hits,
        and each group is ordered by qubit.
    """
    layer_count = intensity.shape[0]
    qubits = intensity.shape[1] // 2 if n_qubits is None else n_qubits
    # (x offset within the layer, first column of the gate group, gate label)
    gate_groups = ((0.25, 0, "RX"), (0.75, qubits, "RY"))
    hits = []
    for layer_idx in range(layer_count):
        for x_offset, first_col, label in gate_groups:
            for qubit in range(qubits):
                value = intensity[layer_idx, first_col + qubit]
                if value > threshold:
                    hits.append((layer_idx + x_offset, qubit, label, value))
    return hits


def get_high_count_mask(intensity, threshold=80):
    """Return an element-wise mask that is True where ``intensity`` exceeds ``threshold``.

    The comparison is strict, so entries equal to ``threshold`` are False.
    For a NumPy array the result is a boolean array of the same shape.
    """
    above_threshold = intensity > threshold
    return above_threshold


def main(threshold=90, model_save_path="Saved_model/trained_resnet_attention.pth",
         qubit_range=range(5, 16), layer_range=range(5, 16)):
    """Sweep circuit sizes, predict intensity maps, and save masks and figures.

    For every ``(n_qubits, n_layers)`` pair the following are written into
    ``Frozenmask_<threshold>``:

    * ``predicted_circuit_qubits{q}_layers{l}_{threshold}.png`` -- the intensity plot;
    * ``high_count_nodes_qubits{q}_layers{l}_{threshold}.npy`` -- the boolean mask
      of entries above ``threshold``, flattened in row-major order;
    * ``high_count_nodes_qubits{q}_layers{l}_{threshold}.png`` -- plot of those entries.

    Args:
        threshold: cut-off applied to the 0-100 normalised intensity.
        model_save_path: path of the saved ``state_dict`` checkpoint.
        qubit_range: qubit counts to sweep (default 5..15).
        layer_range: layer counts to sweep (default 5..15).

    Returns early with a message, creating nothing, if the checkpoint is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(model_save_path):
        print(f"Model file '{model_save_path}' not found. Please check the path!")
        return
    model = ResNet18SelfAttentionPredictor(in_channels=3, out_channels=1, num_heads=4).to(device)
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files
    # (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(model_save_path, map_location=device))

    # Derive the folder from the threshold so the two can never disagree.
    output_dir = f"Frozenmask_{threshold}"
    os.makedirs(output_dir, exist_ok=True)

    for n_qubits in qubit_range:
        for n_layers in layer_range:
            print(f"\n=== Processing n_qubits={n_qubits}, n_layers={n_layers}, threshold={threshold} ===")
            intensity_map = predict_intensity(model, n_qubits=n_qubits, n_layers=n_layers, variant=0, device=device)
            suffix = f"qubits{n_qubits}_layers{n_layers}_{threshold}"

            circuit_filename = os.path.join(output_dir, f"predicted_circuit_{suffix}.png")
            draw_intensity_circuit(intensity_map, n_qubits, n_layers, title="Predicted Circuit", vmin=0, vmax=None, show_cz=True, save_filename=circuit_filename)

            high_count_positions = get_high_count_positions(intensity_map, threshold=threshold, n_qubits=n_qubits)
            flat_mask = get_high_count_mask(intensity_map, threshold=threshold).flatten()
            nodes_filename = os.path.join(output_dir, f"high_count_nodes_{suffix}.npy")
            np.save(nodes_filename, flat_mask)

            num_true = np.count_nonzero(flat_mask)
            total = flat_mask.size
            # Guard against an empty sweep entry (e.g. a zero-sized range value).
            percent = 100.0 * num_true / total if total else 0.0
            print(f"Length of flattened mask: {len(flat_mask)}")
            print(f"Number of True in mask: {num_true}")
            print(f"Percentage of parameters masked: {percent:.2f}%")

            highcount_fig_filename = os.path.join(output_dir, f"high_count_nodes_{suffix}.png")
            draw_high_count_circuit(high_count_positions, n_qubits, n_layers, title="High Count Nodes", show_cz=True, save_filename=highcount_fig_filename)


if __name__ == "__main__":
    # Run the full prediction / mask-generation sweep only when executed as a script.
    main()
